#include "KerasModelExport.h"
#include "model.h"

namespace keras2cpp
{
	/// @brief Constructs the wrapper by creating a keras2cpp Model from @p file_path.
	/// @param file_path Path handed unchanged to the Model constructor
	///        (presumably a model file produced by the keras2cpp exporter — confirm).
	///
	/// NOTE(review): keras_model_ptr_ is a raw owning pointer released in the
	/// destructor. Unless KerasModelExport.h deletes (or properly defines) the copy
	/// constructor/assignment, copying a KerasModel double-deletes the Model.
	/// Consider std::unique_ptr<Model> in the header (Rule of Five / Rule of Zero).
	KerasModel::KerasModel(const std::string& file_path)
		: keras_model_ptr_(new Model(file_path))
	{
	}

	/// @brief Releases the Model allocated in the constructor.
	KerasModel::~KerasModel()
	{
		delete keras_model_ptr_;
	}

	/// @brief Runs the wrapped model on a flat input vector.
	/// @param input_param_vec Input values. They are copied element-by-element into a
	///        Tensor constructed with the input's length as its only dimension.
	///        Assumes Tensor's size constructor allocates data_ with that many elements.
	/// @return The data_ buffer of the tensor returned by the model's call operator.
	std::vector<float> KerasModel::predict(const std::vector<float>& input_param_vec)
	{
		using size_type = std::vector<float>::size_type;

		Tensor tensor{ input_param_vec.size() };
		// Use an unsigned index of the vector's own size type: an int index mixes
		// signed/unsigned in the comparison and would overflow on huge inputs.
		for (size_type i = 0; i < input_param_vec.size(); ++i)
			tensor.data_[i] = input_param_vec[i];
		return (*keras_model_ptr_)(tensor).data_;
	}
}